import numpy as np
from sklearn.cluster import MiniBatchKMeans
from sklearn.metrics import silhouette_score
from common.utils.mlmodel_util import preprocess
from common.utils.mlmodel_util import get_model_info_sklearn
from common.utils.model_util import load_model
from common.utils.format_util import dup_name_handler


def predict(dataframe, para_list, output):
    """Label every row of ``dataframe`` with a cluster id from a saved model.

    Args:
        dataframe: Input table. It is modified in place: one integer column
            holding the predicted cluster ids is written to it.
        para_list: Parameter mapping. ``"model_id"`` selects the stored model
            and ``"feature_col"`` is used when choosing the output column
            name; the whole mapping is also handed to ``preprocess``.
        output: Nested mapping; the chosen output column name is stored at
            ``output["result"]["output_params"]["output_cols"]``.

    Returns:
        The same ``dataframe`` object, with the cluster-id column written.
    """
    # Look up where the model was saved and which features it was trained on;
    # the middle element of the returned triple is not needed here.
    saved_path, _unused, trained_features = get_model_info_sklearn(para_list["model_id"])
    clusterer = load_model(saved_path)

    # Build the feature matrix and predict one label per row.
    features = preprocess(dataframe, para_list, trained_features)
    labels = clusterer.predict(features)

    # NOTE(review): dup_name_handler presumably adjusts the name so it does
    # not clash with the feature columns -- confirm against its definition.
    label_col = dup_name_handler("_cluster_id_", para_list["feature_col"])
    output["result"]["output_params"]["output_cols"] = label_col

    dataframe[label_col] = labels
    dataframe[label_col] = dataframe[label_col].astype("int")
    return dataframe


def train(dataframe, para_list, record):
    """Fit a MiniBatchKMeans model and label every row of ``dataframe``.

    Args:
        dataframe: Input table. It is modified in place: the integer column
            ``"_cluster_id_"`` is written with the cluster label of each row.
        para_list: Parameter mapping. Reads ``"k"`` (number of clusters),
            ``"max_iter"`` and ``"feature_col"``; the whole mapping is also
            handed to ``preprocess``.
        record: Mapping that receives the training summary under the key
            ``"other_info"`` (feature list, feature columns, silhouette
            score, cluster centers, cluster sizes, iteration count).

    Returns:
        A tuple ``(mlmodel, dataframe)``: the fitted estimator and the
        labelled table.

    Raises:
        ValueError: If the preprocessed data has 500000 rows or more. The
            silhouette score is too expensive at that size, so training is
            rejected before any fitting is done.
    """
    # Row count at or above which training is rejected. Computing the
    # silhouette score on 500k rows was observed to take about 40 minutes.
    max_silhouette_rows = 500000

    # --------------------prepare data------------------
    # feature_list is passed to preprocess as an (initially empty) list and
    # recorded afterwards; presumably preprocess fills it in -- confirm.
    feature_list = []
    X = preprocess(dataframe, para_list, feature_list)

    k = para_list["k"]
    max_iter = para_list["max_iter"]

    # Fail fast: check the size limit before paying for the model fit.
    if len(X) >= max_silhouette_rows:
        raise ValueError("当数据量大于500000时，不计算轮廓系数，请选择数据量小于500000的数据集")

    # ---------------------model fit---------------------
    mlmodel = MiniBatchKMeans(n_clusters=k, max_iter=max_iter, batch_size=10000).fit(X)
    pred = mlmodel.labels_

    # ---------------------get metrics---------------
    # Convert to a built-in float so the record holds no NumPy scalar.
    silhouette = float(silhouette_score(X, pred))

    # ---------------------record info----------------------
    record["other_info"] = {
        "feature_list": feature_list,
        "feature_col": para_list["feature_col"],
        "silhouette": silhouette,
        "cluster_centers": mlmodel.cluster_centers_.tolist(),
        # One pass over the labels; minlength keeps an entry (0) for any
        # cluster that received no samples.
        "cluster_sizes": np.bincount(pred, minlength=k).tolist(),
        "numIter": mlmodel.n_iter_,
    }

    # NOTE(review): unlike predict(), the column name is not passed through
    # dup_name_handler, so an existing "_cluster_id_" column is overwritten.
    output_cols = "_cluster_id_"
    dataframe[output_cols] = pred
    dataframe[output_cols] = dataframe[output_cols].astype("int")

    return mlmodel, dataframe
